import os
import numpy as np
import math


def mkdir(path):
    ''' Recursively create the directory ``path`` and any missing parents.

    Does nothing when ``path`` is empty or already exists (including when it
    exists as a regular file, as before).

    The previous hand-rolled recursion never terminated for relative paths:
    ``os.path.split("name")`` yields an empty parent, ``os.path.exists("")`` is
    False, so ``mkdir("")`` kept calling itself until RecursionError.
    '''
    if not path or os.path.exists(path):
        return
    # exist_ok also tolerates the directory being created concurrently elsewhere
    os.makedirs(path, exist_ok=True)

def most_common(lst):
    ''' Return the element that occurs most often in ``lst``.

    Used to pick one unified signal cycle length. Ties are resolved by the
    iteration order of ``set(lst)``. An empty list raises ValueError.
    '''
    distinct = set(lst)
    tally = {item: 0 for item in distinct}
    for item in lst:
        tally[item] += 1
    return max(distinct, key=tally.get)

def LC_Control(scenario, links_obj, link_groups, dLC, cmode):
    ''' Apply Lane Change control upstream of the blocked link of a scenario.

    Walks upstream along the mainline from ``scenario["link"]``, collecting
    links until their total length reaches ``dLC``. On those links, lanes whose
    'Index' is in ``scenario['lane']`` get BlockedVehClasses '10,20'.

    Parameters
    -------------
        scenario: dict
            the scenario dictionary; reads "link" (key of the blocked mainline
            link) and 'lane' (collection of blocked lane indices)
        links_obj: links object of VISSIM
            the links object (accessed via ItemByKey)
        link_groups: list
            the list of link groups; each group's "MAINLINE" entry lists
            mainline link keys in order, concatenated across groups
        dLC: int or float
            the distance of the LC controlled section (m)
        cmode: int
            control mode; when 2, lanes whose index minus 1 is a blocked lane
            are also blocked on the nearest upstream link

    Collection stops early if the next upstream link alone is longer than
    ``dLC`` or the start of the mainline is reached. Earlier versions wrapped
    around to the downstream end of the mainline via negative indexing.
    '''
    mainlinks = [mainlink for group in link_groups for mainlink in group["MAINLINE"]]
    idxBlockedLink = mainlinks.index(scenario["link"])

    links_LC = []
    links_LC_length = 0
    idx = idxBlockedLink - 1
    # idx >= 0: never walk past the first mainline link (a negative index would wrap)
    while links_LC_length < dLC and idx >= 0:
        link_length = links_obj.ItemByKey(mainlinks[idx]).AttValue("Length2D")
        # stop if the next upstream link's length alone exceeds dLC
        if link_length > dLC:
            break
        links_LC.append(mainlinks[idx])
        links_LC_length += link_length
        idx -= 1

    lanes = scenario['lane']
    for link in links_LC:
        for Lane in links_obj.ItemByKey(link).Lanes:
            if Lane.AttValue('Index') in lanes:
                Lane.SetAttValue('BlockedVehClasses', '10,20')
    # links_LC may be empty (no upstream link qualified); nothing to extend then
    if cmode == 2 and links_LC:
        for Lane in links_obj.ItemByKey(links_LC[0]).Lanes:
            if Lane.AttValue('Index') - 1 in lanes:
                Lane.SetAttValue('BlockedVehClasses', '10,20')


def AlineaQ(rmRate, density, Qlen, demand, Qcap, rho):
    ''' Apply Ramp Metering control using modified ALINEAQ algorithm.

    Two candidate rates are computed:
      * a density-feedback term, rmRate - 1 * (density - rho)
      * a queue-override term, demand + 4 * (Qlen - 0.4 * Qcap)
    The larger candidate is clipped to [400, 1800] and returned.
    '''
    floor_rate, ceil_rate = 400, 1800
    density_gain = 1
    queue_gain = 4
    queue_share = 0.4

    density_term = rmRate - density_gain * (density - rho)
    queue_term = demand + queue_gain * (Qlen - queue_share * Qcap)

    candidate = max(density_term, queue_term)
    return max(floor_rate, min(ceil_rate, candidate))


def RLRM(rmRate, density, Qlen, demand, Qcap, rho, vsl):
    ''' Apply Ramp Metering control using reinforcement learning (naive).

    If the queue exceeds 70% of Qcap, meter at the upper bound (1.2 * demand).
    Otherwise use demand - (density - rho) * vsl. The result is clipped to
    [0.7 * demand, 1.2 * demand]. ``rmRate`` is accepted but not used.
    Prints [demand, rate, vsl] on every call.
    '''
    lower = demand*0.7
    upper = demand*1.2
    queue_spilling = Qlen > (Qcap*0.7)
    rate = upper if queue_spilling else demand - (density - rho)*vsl

    print([demand, rate, vsl])
    return max(lower, min(upper, rate))


def VSL_FBL(params, density, flow, vsl, err_accum, rampFlow):
    ''' Apply robust Variable Speed Limit control that handles uncertainties.

    Computes new speed limits for sections start_end[0]..start_end[1] from
    measured densities/flows and writes them into ``vsl`` in place.

    Parameters
    -------------
        params: dict
            controller settings; keys read here: "start_end", "w",
            "perturbations" (multipliers applied to flow, density and w, in
            that order), "C", "vf", "rho_star", "link_groups" (each group a
            dict with 'ONRAMP' / 'OFFRAMP' lists of ramp keys) and
            "min_values" (lower bounds for the first / remaining sections)
        density: sequence
            measured density per section; indices startSection..endSection+1 are read
        flow: sequence
            measured flow per section; indices startSection+1..endSection+2 are read
        vsl: list
            current speed limit per section; updated in place
        err_accum: list
            accumulated density errors (Nsec_controlled + 1 entries); updated in place
        rampFlow: dict
            ramp key -> measured ramp flow

    Returns
    -------------
        (vsl, err_accum): the same objects passed in, after the update
    '''
    startSection = params["start_end"][0]    # startSection: (int) the first section controlled with VSL
    endSection = params["start_end"][1]      # endSection: (int) the discharging section.
    Nsec_controlled = endSection - startSection + 1     # Nsec_controlled is the number of sections under control

    # incorporate uncertainties in w
    w = params["w"] * (1 + params["perturbations"][2])
    # presumably the jam density of a triangular fundamental diagram (C/vf + C/w)
    rho_j = params["C"]/params["vf"] + params["C"]/w

    rho_e = [params["rho_star"]] * (Nsec_controlled + 1)
    rho_true = density[startSection: (endSection + 2)]
    q_true = flow[(startSection + 1): (endSection + 3)]

    # incorporate measurement uncertainties
    rho = [(1 + params["perturbations"][1]) * x for x in rho_true]
    q = [(1 + params["perturbations"][0]) * x for x in q_true]

    # get ramp flows attached to each controlled section
    onFlow = [0.0] * Nsec_controlled
    offFlow = [0.0] * Nsec_controlled
    for i in range(Nsec_controlled):
        group = params["link_groups"][i + startSection + 1]
        for onramp in group['ONRAMP']:
            onFlow[i] += rampFlow[onramp]
        for offramp in group['OFFRAMP']:
            offFlow[i] += rampFlow[offramp]

    # compute errors in densities and accumulate them (integral term)
    x = np.subtract(rho, rho_e)
    for i in range(len(x)):
        err_accum[i] += x[i]

    # we need N+1 measurements to produce N VSL commands
    Lambda1 = [20.0]*(Nsec_controlled)
    Lambda2 = [2.0]*(Nsec_controlled)
    c = [-500.0]*(Nsec_controlled)
    v = [0.0]*(Nsec_controlled)
    qv = [0.0]*(Nsec_controlled)
    for i in range(Nsec_controlled):
        qv[i] = q[i+1] + offFlow[i] - onFlow[i] - Lambda1[i]*x[i+1] - Lambda2[i]*(err_accum[i+1] + c[i])

    for i in range(Nsec_controlled):
        if i == 0:
            v[i] = qv[i] * w / (w*rho_j - qv[i])
        elif rho[i] == 0:
            # empty section: command free-flow speed (was a no-op `==` comparison)
            v[i] = params["vf"]
        else:
            v[i] = qv[i] / rho[i]

    # round to 10, rate-limit to +/-10 per call, and bound by min_values / vf
    for i in range(Nsec_controlled):
        v[i] = round(v[i] * 0.1) * 10
        if v[i] <= vsl[i + startSection]:
            if i == 0:
                v[i] = max(v[i], vsl[i + startSection] - 10, params["min_values"][0])
            else:
                v[i] = max(v[i], vsl[i + startSection] - 10, params["min_values"][1])
        else:
            v[i] = min(v[i], vsl[i + startSection] + 10, params["vf"])
        vsl[i + startSection] = v[i]

    return vsl, err_accum


def VSL_NoCtrl(params, vsl):
    ''' Add speed limit constraints for open-loop scenarios to mimic real-world traffic patterns.

    Active only when params["mode"] == 9. In that case a fresh, hard-coded list
    of 8 speed limits is returned, regardless of the length of ``vsl``.
    Otherwise ``vsl`` itself is returned unchanged.
    '''
    if params["mode"] == 9:
        # Hard-coded VSL commands; a new list each call so callers may mutate it
        return [100, 90, 70, 50, 30, 100, 100, 100]
    return vsl

def Arterial_TSC(signalControllers, params, Cycs, inci_stage=0):
    ''' Assign signal program numbers ('ProgNo') to the arterial controllers.

    Controllers listed in params["intSC_crit"] get entry ``inci_stage`` of
    params["prog_crit_defcyc"] (or "prog_crit_optcyc" when TSCmode == 1).
    All others get int(cycle / 20 - 2). The cycle is params["default_cyc"],
    or the per-controller Cycs entry when TSCmode == 1.
    '''
    crit_programs = params["prog_crit_defcyc"]
    optimized = params["TSCmode"] == 1
    if optimized:
        crit_programs = params["prog_crit_optcyc"]

    for pos, sc_id in enumerate(params["intSC"]):
        ctrl = signalControllers.ItemByKey(sc_id)
        if sc_id in params["intSC_crit"]:
            ctrl.SetAttValue('ProgNo', crit_programs[inci_stage])
            continue
        program = int(params["default_cyc"] / 20 - 2)
        if optimized:
            program = int(Cycs[pos] / 20 - 2)
        ctrl.SetAttValue('ProgNo', program)
        


def setVehRouting(net, params, inci_flag=True):
    ''' Change static vehicle routings, hard coding using routing id.

    Sets 'DiffusTm' on every driving behavior to params["diffusionTime"].
    With ``inci_flag``, only the routing decisions in params["changedID"] get
    the shares in params["changedRF"]. Otherwise every decision in
    params["appID"] gets params["appRF"] and every decision in
    params["onrampID"] gets params["onrampRF"].
    When params["std-mean-ratio"] > 0, each share is replaced by a draw from
    np.random.normal(share, share * ratio).
    '''
    approach_shares = params["appRF"]        # original relative flows for left/straight/right
    onramp_shares = params["onrampRF"]       # original relative flows for onramp/arterial
    spread = params["std-mean-ratio"]
    routes = net.VehicleRoutingDecisionsStatic
    for behavior in net.DrivingBehaviors:
        behavior.SetAttValue('DiffusTm', params["diffusionTime"])

    def _assign(route_key, shares):
        # write one relative flow per route station (station keys are 1-based)
        decision = routes.ItemByKey(route_key)
        for pos in range(decision.VehRoutSta.Count):
            share = shares[pos]
            value = np.random.normal(share, share*spread) if spread > 0 else share
            decision.VehRoutSta.ItemByKey(pos+1).SetAttValue('RelFlow(1)', value)

    if inci_flag:
        # changed relative flows due to the incident or event
        for j, route_key in enumerate(params["changedID"]):
            _assign(route_key, params["changedRF"][j])
        return

    # relative flows for all approaches at each intersection, then each on-ramp
    for route_key in params["appID"]:
        _assign(route_key, approach_shares)
    for route_key in params["onrampID"]:
        _assign(route_key, onramp_shares)
    

def getEstDemands(demands, params_Routes, inci_flag=False):
    ''' Estimate demands for on-ramps and intersections.

    ``demands`` is laid out as
    [mainline, S boundary, W_0..W_{n-1}, N boundary, E_0..E_{n-1}],
    where n = (len(demands) - 3) // 2 is the number of intersections.
    Off-ramp flows (mainline demand times fixed ratios) are added to the W
    approaches listed in offramp_idx. S approach demands are propagated
    forward from the S boundary and N approach demands backward from the N
    boundary, using the turning ratios in params_Routes["appRF"]. With
    ``inci_flag``, those ratios (and the on-ramp share) are overridden via
    "changedPos"/"changedIdx"/"changedRF".

    Returns ([S, W, N, E] approach demand lists, on-ramp demand list).
    '''
    n_int = (len(demands) - 3) // 2      # number of intersections
    onramp_idx = [0, 1, 3, 4, 5]         # the index of intersections associated with on-ramps
    offramp_idx = [1, 2, 3, 4, 6]        # the index of intersections associated with off-ramps
    offramp_ratios = [0.1, 0.05, 0.03, 0.01, 0.01]
    pending_offramp = iter([demands[0] * ratio for ratio in offramp_ratios])

    base_split = params_Routes["appRF"]
    splits = {
        "AppS": [base_split] * n_int,
        "AppW": [base_split] * n_int,
        "AppN": [base_split] * n_int,
        "AppE": [base_split] * n_int,
    }
    ramp_split = [params_Routes["onrampRF"][0]] * len(onramp_idx)
    if inci_flag:
        for k in range(len(params_Routes["changedPos"])):
            pos = params_Routes["changedPos"][k]
            if pos in ("AppS", "AppW", "AppN", "AppE"):
                splits[pos][params_Routes["changedIdx"][k]] = params_Routes["changedRF"][k]
            elif pos == "Ramp":
                ramp_split[params_Routes["changedIdx"][k]] = params_Routes["changedRF"][k][0]
    y_S, y_W, y_N, y_E = splits["AppS"], splits["AppW"], splits["AppN"], splits["AppE"]

    dS = demands[1]
    dWs = demands[2:(2 + n_int)]
    dN = demands[2 + n_int]
    d_E = demands[(3 + n_int):]

    # W approaches additionally receive off-ramp flow where an off-ramp exists
    d_W = []
    for j in range(n_int):
        west = dWs[j]
        if j in offramp_idx:
            west += next(pending_offramp)
        d_W.append(west)

    south, north = [], []
    if n_int > 0:
        south.append(dS)
        north.append(dN)
    # southbound: each S approach is fed by the previous intersection
    for k in range(1, n_int):
        south.append(d_W[k-1]*y_W[k-1][0] + south[k-1]*y_S[k-1][1] + d_E[k-1]*y_E[k-1][2])
    # northbound: walk from the last intersection back to the first
    for k in range(n_int - 1, 0, -1):
        north.append(d_W[k]*y_W[k][2] + north[-1]*y_N[k][1] + d_E[k]*y_E[k][0])
    north.reverse()

    ramp_shares = iter(ramp_split)
    onramp_demands = []
    for j in range(n_int):
        if j in onramp_idx:
            feeding = south[j]*y_S[j][0] + d_E[j]*y_E[j][1] + north[j]*y_N[j][2]
            onramp_demands.append(feeding*next(ramp_shares))

    return [south, d_W, north, d_E], onramp_demands

# For 4-phase plan, cap is manually tuned
def getOptCycle(d_intersecs, cap=7200, alpha=76.9, beta=-186.2, LT=16):
    ''' Compute optimal cycle length (s) for every intersection from its demands.

    ``d_intersecs`` holds one demand list per approach; column i is
    intersection i. The summed flow ratio Y = sum(demand / cap) maps to
    alpha * ln(LT / (1 - Y)) + beta, clamped to [60, 180]. When Y >= 1 the
    result is 180. Each cycle is snapped to the nearest multiple of 20 s.
    '''
    def _cycle_for(col):
        flow_ratio = sum(approach[col] / cap for approach in d_intersecs)
        cycle = 180
        if flow_ratio < 1:
            cycle = min(180, max(60, alpha*np.log(LT/(1-flow_ratio)) + beta))
        return int(np.round(cycle * 0.05) * 20)

    return [_cycle_for(col) for col in range(len(d_intersecs[0]))]

# For 5-phase plan, cap is manually tuned
def getUniCycle(d_intersecs, cap=10800, alpha=136.8, beta=-357.7, LT=16):
    ''' Compute the unified cycle length (s) based on given demands.

    Each intersection's candidate cycle is computed the same way as in
    getOptCycle, with this function's own defaults. The most common candidate
    is then returned once per intersection.
    '''
    candidate_cycles = []
    for site in range(len(d_intersecs[0])):
        load = 0
        for approach in d_intersecs:
            load += approach[site] / cap
        if load < 1:
            cycle = min(180, max(60, alpha*np.log(LT/(1-load)) + beta))
        else:
            cycle = 180
        candidate_cycles.append(int(np.round(cycle * 0.05) * 20))
    # Take the most common opt cycle as the unified cycle
    shared = most_common(candidate_cycles)
    return [shared] * len(candidate_cycles)

def _queue_offset(Nq, distance, Ta, qs, va):
    ''' Offset (s) for one link from its estimated queue Nq, rounded to 5 s and floored at 0. '''
    # The linear queue model predicts negative queues at low demand; treat those
    # as an empty queue instead of passing a negative value to math.sqrt.
    Nq = max(float(Nq), 0.0)
    if Nq < qs*Ta/2:
        Tq = math.sqrt(2*Nq*Ta/qs)
    else:
        Tq = Nq/qs + Ta/2
    # distance*3.6/va: travel time in s, assuming distance in m and va in km/h
    To = Ta + distance*3.6/va - Tq
    return max(int(np.round(To*0.2)*5), 0)


def getOffsets(d_intersecs, params_Routes, distances, Ta=6, b1=0.01, b2=-4.1, qs=1, va=60):
    ''' Compute offsets btw adjacent intersections using estimated approach demand.

    The queue is assumed linear in demand: Nq = b1 * demand + b2, clamped at
    0. Southbound link i uses the S-approach queue of intersection i+1.
    Northbound link i uses the N-approach queue of intersection i.
    ``params_Routes`` is accepted for interface compatibility and not used.

    Returns (Offsets_SB, Offsets_NB), one non-negative multiple of 5 per distance.
    '''
    d_Sapp_ar = np.array(d_intersecs[0])
    d_Napp_ar = np.array(d_intersecs[2])
    b2_ar = np.array([b2]*(len(distances)+1))
    Nq_SB_ar = b1*d_Sapp_ar + b2_ar     # Southbound queue estimation
    Nq_NB_ar = b1*d_Napp_ar + b2_ar     # Northbound queue estimation
    Offsets_SB = [_queue_offset(Nq_SB_ar[i+1], distances[i], Ta, qs, va) for i in range(len(distances))]
    Offsets_NB = [_queue_offset(Nq_NB_ar[i], distances[i], Ta, qs, va) for i in range(len(distances))]
    return Offsets_SB, Offsets_NB

def Arterial_TSC_UniCyc(signalControllers, params, Cycs, Offsets):
    ''' Assign signal programs based on the cycle length and the offset.

    TSCmode 0: default cycle and default offset for every controller.
    TSCmode 1: the given Cycs with the default offset.
    TSCmode 2: cumulative offsets starting at 0.
    TSCmode 3-5: cumulative two-way offsets, where each step is 0 or half a cycle.
    Program numbers: 60 s -> 1 + offset//5, 80 s -> 13 + ..., 100 s -> 29 + ...,
    120 s -> 49 + ...; any other cycle -> cycle/20 + 66 (offset ignored).
    Returns the sequential offsets that were used.
    '''
    seq_offsets = [params["default_offset"]] * len(Cycs)
    mode = params["TSCmode"]
    if mode == 0:
        Cycs = [params["default_cyc"]] * len(Cycs)
    elif mode == 2:
        # convert individual offsets into sequential offsets where the first offset is 0
        seq_offsets = [0]
        for gap in Offsets:
            seq_offsets.append(seq_offsets[-1] + gap)
    elif mode in (3, 4, 5):
        # only 2 options for two-way offsets, 0 or half cycle
        seq_offsets = [0]
        base_cycle = int(Cycs[0])
        half = int(base_cycle/2)
        for gap in Offsets:
            wrapped = gap % base_cycle
            if abs(wrapped - half) < int(half/2):
                seq_offsets.append(seq_offsets[-1] + half)
            else:
                seq_offsets.append(seq_offsets[-1])

    first_program = {60: 1, 80: 13, 100: 29, 120: 49}
    for pos, sc_id in enumerate(params["intSC"]):
        ctrl = signalControllers.ItemByKey(sc_id)
        cycle = int(Cycs[pos])
        offset = int(seq_offsets[pos] % cycle)
        if cycle in first_program:
            program = int(offset/5) + first_program[cycle]
        else:
            program = int(cycle/20) + 66
        ctrl.SetAttValue('ProgNo', program)

    return seq_offsets

def getTimeDeviations(cyc, seq_offsets, Offsets_SB, Offsets_NB):
    ''' Compute time deviations for 2-way offsets.

    For each link, the deviation is achieved offset minus ideal offset (both
    modulo ``cyc``), wrapped into [-cyc/2, cyc/2]. The northbound achieved
    offset is the complement (cyc minus the southbound one). Prints and returns
    [southbound deviations, northbound deviations].
    '''
    def _wrap(dev):
        if dev > cyc/2:
            return dev - cyc
        if dev < -cyc/2:
            return dev + cyc
        return dev

    southbound = []
    for idx, target in enumerate(Offsets_SB):
        ideal = target % cyc
        achieved = (seq_offsets[idx+1] - seq_offsets[idx]) % cyc
        southbound.append(_wrap(achieved - ideal))

    northbound = []
    for idx, target in enumerate(Offsets_NB):
        ideal = target % cyc
        achieved = cyc - ((seq_offsets[idx+1] - seq_offsets[idx]) % cyc)
        northbound.append(_wrap(achieved - ideal))

    time_deviations = [southbound, northbound]
    print(time_deviations)
    return time_deviations

def ArterialSpeedAdvisory(VSLs, time_deviations, ASA_2way, distances, va):
    ''' Apply arterial speed advisory to eliminate time deviations.

    For link i, the advised speed is distance*va / (deviation*va/3.6 + distance).
    It is rounded up when within [50, 75], otherwise rounded to a multiple of 5.
    The speed is written to both vehicle classes 10 (car) and 20 (HGV) of two
    signs per direction: ASA_2way[0][2i], [2i+1] southbound and
    ASA_2way[1][-2i-1], [-2i-2] northbound. The southbound speed is printed.
    '''
    def _advised_speed(distance, deviation):
        raw = distance*va / (deviation*va/3.6 + distance)
        speed = int(math.ceil(raw))
        if raw < 50 or raw > 75:
            speed = int(round(raw*0.2)*5)
        return speed

    def _apply(sign_key, speed):
        sign = VSLs.ItemByKey(sign_key)
        sign.SetAttValue("DesSpeedDistr(10)", speed)   # car
        sign.SetAttValue("DesSpeedDistr(20)", speed)   # HGV

    for idx, distance in enumerate(distances):
        # Southbound speed advisory
        sb_speed = _advised_speed(distance, time_deviations[0][idx])
        _apply(ASA_2way[0][2*idx], sb_speed)
        _apply(ASA_2way[0][2*idx + 1], sb_speed)
        print(sb_speed)
        # Northbound speed advisory
        nb_speed = _advised_speed(distance, time_deviations[1][idx])
        _apply(ASA_2way[1][-2*idx - 1], nb_speed)
        _apply(ASA_2way[1][-2*idx - 2], nb_speed)
        
